# Reference implementation of MK-CAViT consistent with the Methodology section.
# - Three-path multi-kernel tokenization (3x3/stride 2, 7x7/stride 2, 15x15/stride 1)
# - Fast-HGR attention (cosine matrix + trace-of-covariances term)
# - Gated small/mid fusion and adaptive mixing with large-scale context
# - Adaptive multi-head gating
# - Classification head and segmentation head
# The module exposes mk_cavit_tiny/small/base builders.

from typing import Tuple, Optional, Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------
# helpers
# -----------------------------

class DropPath(nn.Module):
    """
    Stochastic depth ("drop path"), applied per sample.

    In training mode each sample of the batch (first axis) is zeroed with
    probability ``drop_prob``; surviving samples are rescaled by
    ``1 / (1 - drop_prob)`` so the expected value is unchanged. In eval mode,
    or when ``drop_prob == 0``, the input object is returned unchanged.

    Args:
        drop_prob: probability in [0, 1] of dropping a sample.

    Raises:
        ValueError: if ``drop_prob`` is outside [0, 1] (or NaN).
    """
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        drop_prob = float(drop_prob)
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep = 1.0 - self.drop_prob
        if keep <= 0.0:
            # Every sample is dropped; rescaling by keep == 0 would give 0/0 -> NaN.
            return torch.zeros_like(x)
        # One Bernoulli draw per sample, broadcast over all remaining axes.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep)
        return x.div(keep) * mask


def _l2_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x / (x.norm(dim=-1, keepdim=True) + eps)


def _batch_cov(x: torch.Tensor) -> torch.Tensor:
    """
    Compute covariance matrix per batch: cov = (X^T X) / (N-1), where
    X is mean-centered tokens (B, N, D).
    Returns: (B, D, D)
    """
    x = x - x.mean(dim=1, keepdim=True)  # center along tokens
    # (B, N, D) -> (B, D, N) @ (B, N, D) = (B, D, D)
    return torch.matmul(x.transpose(1, 2), x) / max(x.shape[1] - 1, 1)


def fast_hgr_logits(Q: torch.Tensor, K: torch.Tensor, lam: float) -> torch.Tensor:
    """
    Fast-HGR attention logits:
      logits[b, i, j] = cos(Q[b, i], K[b, j]) + lam * tr( cov(Q[b]) @ cov(K[b]) )

    ``cos`` is the all-pairs cosine similarity and ``cov`` the token covariance
    from ``_batch_cov`` (centered over the token axis, divided by N - 1).

    The trace term is a single scalar per batch element broadcast over
    (Nq, Nk). It shifts every logit in a row by the same amount and therefore
    cancels under a softmax over the last axis; it only affects consumers that
    use the raw logits.

    Args:
        Q, K: (B, Nq, D), (B, Nk, D); batch size and feature width must match.
        lam: weight of the trace term.
    Returns:
        (B, Nq, Nk) logits
    """
    Qn = _l2_normalize(Q)                       # (B, Nq, D)
    Kn = _l2_normalize(K)                       # (B, Nk, D)
    # All-pairs cosine similarity = Qn @ Kn^T
    cos = torch.matmul(Qn, Kn.transpose(1, 2))  # (B, Nq, Nk)

    covQ = _batch_cov(Q)                        # (B, D, D)
    covK = _batch_cov(K)                        # (B, D, D)
    # tr(A @ B) == sum_ij A[i, j] * B[j, i]: O(D^2) work instead of building the
    # full D x D product with a matmul (O(D^3)).
    tr = (covQ * covK.transpose(-2, -1)).sum(dim=(-2, -1))  # (B,)
    return cos + lam * tr.view(-1, 1, 1)


# -----------------------------
# modules
# -----------------------------

class MultiScaleTokenizer(nn.Module):
    """
    Multi-kernel patch tokenizer with three parallel convolutional branches.

    Each branch is a bias-free ``Conv2d(in_ch, embed_dim)`` whose output map is
    flattened to tokens and normalised with its own LayerNorm:
      - 's' (small): 3x3 kernel, stride 2, padding 1   -> ceil(H/2) x ceil(W/2) map
      - 'm' (mid):   7x7 kernel, stride 2, padding 3   -> ceil(H/2) x ceil(W/2) map
      - 'l' (large): 15x15 kernel, stride 1, padding 7 -> H x W map (full resolution)

    ``forward`` returns ``{'s': (tokens, (h, w)), 'm': ..., 'l': ...}`` with
    tokens shaped (B, h*w, embed_dim). ``img_size`` is stored but not used by
    the layers.
    """
    def __init__(self, in_ch: int, embed_dim: int, img_size: int = 224):
        super().__init__()
        # padding = kernel_size // 2 in every branch, so only the stride sets the output size.
        self.s = nn.Conv2d(in_ch, embed_dim, kernel_size=3, stride=2, padding=1, bias=False)
        self.m = nn.Conv2d(in_ch, embed_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.l = nn.Conv2d(in_ch, embed_dim, kernel_size=15, stride=1, padding=7, bias=False)

        # Per-branch token normalisation over the channel dimension.
        self.norm_s = nn.LayerNorm(embed_dim)
        self.norm_m = nn.LayerNorm(embed_dim)
        self.norm_l = nn.LayerNorm(embed_dim)

        self.img_size = img_size
        self.embed_dim = embed_dim

    def _to_tokens(self, x: torch.Tensor, conv: nn.Conv2d, norm: nn.LayerNorm) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Run one branch: conv -> (B, h*w, C) tokens -> LayerNorm; also return (h, w)."""
        feat = conv(x)                                   # (B, C, h, w)
        spatial = (feat.shape[-2], feat.shape[-1])
        tokens = norm(feat.flatten(2).transpose(1, 2))   # (B, h*w, C)
        return tokens, spatial

    def forward(self, x: torch.Tensor) -> Dict[str, Tuple[torch.Tensor, Tuple[int, int]]]:
        branches = (
            ('s', self.s, self.norm_s),
            ('m', self.m, self.norm_m),
            ('l', self.l, self.norm_l),
        )
        return {key: self._to_tokens(x, conv, norm) for key, conv, norm in branches}


class FastHGRAttention(nn.Module):
    """
    Multi-head attention driven by Fast-HGR logits, with adaptive head gating.

    One (B, Nq, Nk) attention map is computed from the full-width query and
    key projections (scaled by 1/sqrt(head_dim), softmax over keys) and is
    shared by every head. Head h's output is multiplied by sigmoid(g_h), where
    g_h is a learnable gate initialised to 0 (gate value 0.5). The gated heads
    are concatenated and passed through a single Linear(dim, dim) projection.

    Since the Fast-HGR trace term is constant per sample, it does not change
    the softmax attention weights here.
    """
    def __init__(self, dim: int, num_heads: int = 4, qkv_bias: bool = True, attn_drop: float = 0.0,
                 proj_drop: float = 0.0, lam: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # Full-width projections; heads are views into the value projection.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        # One gate logit per head; sigmoid(0) = 0.5 at initialisation.
        self.gates = nn.Parameter(torch.zeros(num_heads))
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.lam = lam

    def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_q: (B, Nq, D) query tokens
            x_kv: (B, Nk, D) key/value tokens
        Returns:
            (B, Nq, D)
        """
        batch, n_q, width = x_q.shape
        n_kv = x_kv.shape[1]
        heads, head_dim = self.num_heads, self.head_dim

        # Shared attention map from the head-concatenated query/key projections.
        scores = fast_hgr_logits(self.q(x_q), self.k(x_kv), self.lam)
        weights = self.attn_drop(scores.div(math.sqrt(head_dim)).softmax(dim=-1))  # (B,Nq,Nk)

        # Apply the shared map to every head's values: (B,Nq,Nk) x (B,Nk,H,Hd) -> (B,Nq,H,Hd)
        values = self.v(x_kv).view(batch, n_kv, heads, head_dim)
        mixed = torch.einsum("bij,bjhd->bihd", weights, values)

        gate = self.gates.sigmoid().view(1, 1, heads, 1)
        merged = (mixed * gate).reshape(batch, n_q, width)
        return self.proj_drop(self.proj(merged))


class MLP(nn.Module):
    """
    Position-wise feed-forward network:
    Linear(dim -> hidden) -> GELU -> Dropout -> Linear(hidden -> dim) -> Dropout.
    The same Dropout module (probability ``drop``) is used after both layers.
    """
    def __init__(self, dim: int, hidden: int, drop: float = 0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)   # expand
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)   # project back
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Block(nn.Module):
    """
    Pre-norm transformer block built around Fast-HGR (cross-)attention.

    Queries come from ``x_q`` and keys/values from ``x_kv`` (pass the same
    tensor twice for self-attention). ``x_q`` carries the residual stream:
        x = x_q + DropPath(Attn(LN_q(x_q), LN_kv(x_kv)))
        x = x   + DropPath(MLP(LN_2(x)))
    ``drop`` sets both the attention and MLP dropout rates.
    """
    def __init__(self, dim: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.0,
                 lam: float = 0.1):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = FastHGRAttention(dim, num_heads=heads, attn_drop=drop, proj_drop=drop, lam=lam)
        self.drop_path = DropPath(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden=int(dim * mlp_ratio), drop=drop)

    def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor) -> torch.Tensor:
        attended = self.attn(self.norm_q(x_q), self.norm_kv(x_kv))
        stream = x_q + self.drop_path(attended)
        return stream + self.drop_path(self.mlp(self.norm2(stream)))


class MultiScaleFusion(nn.Module):
    """
    Two-stage fusion of the small (s), mid (m) and large (l) token streams:
      1) local-mid fusion using cross Fast-HGR blocks:
           A_sm = sig(alpha) * Block(s <- m) + sig(beta) * Block(m <- s)
         alpha/beta are per-channel gate logits initialised to 0 (weight 0.5).
         The element-wise sum requires s and m to have the same token count.
      2) global mixing with large context:
           gamma = Sigmoid(Linear(LayerNorm(mean_tokens(A_sm))))   # (B, 1)
           out   = gamma * Block(A_sm <- l) + (1 - gamma) * Block(A_sm <- A_sm)

    Args:
        dim: token embedding dimension.
        heads: attention heads in every inner Block.
        lam: Fast-HGR trace weight forwarded to every inner Block.
        mlp_ratio: MLP hidden-size multiplier for every inner Block.
        drop: dropout rate (attention + MLP) inside every inner Block.
        drop_path: stochastic-depth rate for every inner Block.
    """
    def __init__(self, dim: int, heads: int, lam: float = 0.1, mlp_ratio: float = 4.0,
                 drop: float = 0.0, drop_path: float = 0.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(dim))
        self.beta  = nn.Parameter(torch.zeros(dim))
        block_kwargs = dict(mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path, lam=lam)
        self.local_from_m = Block(dim, heads, **block_kwargs)  # s <- m
        self.local_from_s = Block(dim, heads, **block_kwargs)  # m <- s
        self.mix_self = Block(dim, heads, **block_kwargs)      # A_sm <- A_sm
        self.mix_with_L = Block(dim, heads, **block_kwargs)    # A_sm <- L
        self.gamma_proj = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, s: torch.Tensor, m: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        # s <- m and m <- s (cross), blended with per-channel gates
        a = torch.sigmoid(self.alpha).view(1, 1, -1)
        b = torch.sigmoid(self.beta).view(1, 1, -1)
        s_new = self.local_from_m(s, m)
        m_new = self.local_from_s(m, s)
        a_sm = a * s_new + b * m_new

        # global context integration, gated per sample
        pooled = a_sm.mean(dim=1)  # (B,D)
        gamma = self.gamma_proj(pooled)  # (B,1)
        a_self = self.mix_self(a_sm, a_sm)  # self-mixing
        a_L    = self.mix_with_L(a_sm, l)  # attend into large branch
        out = gamma[:, None, :] * a_L + (1.0 - gamma[:, None, :]) * a_self
        return out


# -----------------------------
# main model
# -----------------------------

class MK_CAViT(nn.Module):
    """
    MK-CAViT backbone with:
      - Three-path tokenization (small / mid / large kernels)
      - ``depth`` stacked MultiScaleFusion layers: the small-branch tokens form
        the working stream, the mid and large tokens are fixed context
      - Classification head (mean-pooled tokens) and segmentation head
      - Optional auxiliary Fast-HGR term on the fused stream (computed only
        when labels are provided)

    ``drop`` sets the positional dropout and the attention/MLP dropout of every
    fusion layer; ``drop_path`` is the final value of a stochastic-depth rate
    that increases linearly over the ``depth`` fusion layers; ``mlp_ratio`` and
    ``lam`` are forwarded to every fusion layer.

    NOTE(review): the large branch keeps full input resolution, so it yields
    img_size**2 tokens (50176 at 224), and each ``A_sm <- L`` attention map is
    (B, ceil(img_size/2)**2, img_size**2). Memory grows quickly with image size.
    """
    def __init__(
        self,
        num_classes: int = 1000,
        img_size: int = 224,
        in_chans: int = 3,
        embed_dim: int = 96,
        depth: int = 6,
        heads: int = 4,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        drop_path: float = 0.1,
        lam: float = 0.1
    ):
        super().__init__()
        self.num_classes = num_classes
        self.lam = lam  # Fast-HGR trace weight, shared by the attention layers and the aux loss
        self.tokenizer = MultiScaleTokenizer(in_chans, embed_dim, img_size)
        self.pos_drop = nn.Dropout(drop)

        # Learned 1D positions per branch, sized for an img_size x img_size input.
        # The stride-2 small/mid convs (padding = kernel // 2) output ceil(img_size / 2)
        # per side; the stride-1 large conv keeps img_size per side.
        half = (img_size + 1) // 2
        self.pos_s = nn.Parameter(torch.zeros(1, half * half, embed_dim))
        self.pos_m = nn.Parameter(torch.zeros(1, half * half, embed_dim))
        self.pos_l = nn.Parameter(torch.zeros(1, img_size * img_size, embed_dim))

        # Stack of fusion layers with a linearly increasing stochastic-depth rate.
        dpr = torch.linspace(0, drop_path, depth).tolist()
        self.blocks = nn.ModuleList([
            MultiScaleFusion(embed_dim, heads, lam=lam, mlp_ratio=mlp_ratio, drop=drop, drop_path=rate)
            for rate in dpr
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # segmentation head: 3x3 conv-BN-ReLU followed by a 1x1 classifier conv,
        # applied to the fused tokens reshaped to the small-branch map
        self.seg_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, num_classes, kernel_size=1)
        )

        nn.init.trunc_normal_(self.pos_s, std=0.02)
        nn.init.trunc_normal_(self.pos_m, std=0.02)
        nn.init.trunc_normal_(self.pos_l, std=0.02)

    # ---- utilities ----

    @staticmethod
    def _add_pos(x: torch.Tensor, pos: torch.Tensor, hw: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """
        Add positional embeddings ``pos`` (1, P, D) to tokens ``x`` (B, N, D).

        - P == N: added directly.
        - ``hw`` given with hw[0] * hw[1] == N and P a perfect square: the stored
          sqrt(P) x sqrt(P) grid is bilinearly resized to ``hw`` so that the
          embeddings stay spatially aligned with the row-major token order.
        - otherwise (fallback): crop to the first N entries when P > N, else
          1-D linear interpolation along the sequence.
        """
        n_tokens, n_pos = x.size(1), pos.size(1)
        if n_pos == n_tokens:
            return x + pos
        if hw is not None and hw[0] * hw[1] == n_tokens:
            side = math.isqrt(n_pos)
            if side * side == n_pos:
                grid = pos.transpose(1, 2).reshape(1, pos.size(2), side, side)
                grid = F.interpolate(grid, size=hw, mode='bilinear', align_corners=False)
                return x + grid.flatten(2).transpose(1, 2)
        if n_pos > n_tokens:
            return x + pos[:, :n_tokens, :]
        return x + F.interpolate(
            pos.transpose(1, 2),
            size=n_tokens,
            mode='linear',
            align_corners=False
        ).transpose(1, 2)

    # ---- forward paths ----

    def forward_backbone(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Returns:
            fused tokens (B, N, D) and (Hs, Ws) size for small branch (for seg head)
        """
        branches = self.tokenizer(x)
        s, s_hw = branches['s']
        m, m_hw = branches['m']
        l, l_hw = branches['l']

        s = self._add_pos(s, self.pos_s, s_hw)
        m = self._add_pos(m, self.pos_m, m_hw)
        l = self._add_pos(l, self.pos_l, l_hw)

        s, m, l = self.pos_drop(s), self.pos_drop(m), self.pos_drop(l)

        # multi-layer fusion; each layer refines the small-branch "working stream"
        fused = s
        for blk in self.blocks:
            fused = blk(fused, m, l)

        return fused, s_hw

    def forward(self, x: torch.Tensor, labels: Optional[torch.Tensor] = None, mu: float = 0.1):
        """
        Classification forward. If labels are provided, also computes the loss
        ``task_loss - mu * aux``.

        ``task_loss`` is BCE-with-logits for 2-D (multi-label) labels and
        cross-entropy otherwise. ``aux`` is the mean Fast-HGR logit between the
        first and second halves of the fused token sequence (0 when there are
        fewer than 2 tokens, where one half would be empty).

        Returns:
            logits or (logits, loss) if labels is not None
        """
        fused, _ = self.forward_backbone(x)
        cls_token = fused.mean(dim=1)          # global average pooling in token space
        logits = self.head(self.norm(cls_token))

        if labels is None:
            return logits

        if labels.ndim == 2:  # multi-label (e.g., COCO)
            ce = F.binary_cross_entropy_with_logits(logits, labels.float())
        else:
            ce = F.cross_entropy(logits, labels.long())

        n_tokens = fused.size(1)
        if n_tokens < 2:
            # An empty half would make the covariance (and the loss) NaN.
            aux = fused.new_zeros(())
        else:
            half = n_tokens // 2
            aux = fast_hgr_logits(fused[:, :half, :], fused[:, half:, :], lam=self.lam).mean()

        loss = ce - mu * aux
        return logits, loss

    def _features_map(self, x: torch.Tensor) -> torch.Tensor:
        """Fused tokens reshaped to a (B, D, Hs, Ws) map; keeps the autograd graph."""
        fused, (Hs, Ws) = self.forward_backbone(x)
        B, N, D = fused.shape
        assert N == Hs * Ws, "Token count does not match small-branch spatial size"
        return fused.transpose(1, 2).contiguous().view(B, D, Hs, Ws)

    @torch.no_grad()
    def forward_features_map(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return a (B,C,Hs,Ws) feature map derived from the fused token stream
        by reshaping back to small-branch spatial size. Runs without gradient
        tracking (inference / feature extraction only).
        """
        return self._features_map(x)

    def forward_seg(self, x: torch.Tensor, out_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """
        Dense segmentation forward. Produces (B, num_classes, H, W) logits.

        Gradients flow through the backbone, so this path can be trained
        end to end; ``out_size`` (if given) bilinearly resizes the logits.
        """
        fmap = self._features_map(x)              # (B, D, Hs, Ws)
        logits = self.seg_head(fmap)              # (B, C, Hs, Ws)
        if out_size is not None:
            logits = F.interpolate(logits, size=out_size, mode='bilinear', align_corners=False)
        return logits


# -----------------------------
# builders
# -----------------------------

def mk_cavit_tiny(num_classes: int = 1000, img_size: int = 224) -> MK_CAViT:
    """Build MK-CAViT-Tiny (embed_dim=64, depth=4, heads=4, drop_path=0.05)."""
    variant = dict(embed_dim=64, depth=4, heads=4, drop_path=0.05)
    return MK_CAViT(num_classes=num_classes, img_size=img_size, **variant)

def mk_cavit_small(num_classes: int = 1000, img_size: int = 224) -> MK_CAViT:
    """Build MK-CAViT-Small (embed_dim=96, depth=6, heads=6, drop_path=0.1)."""
    variant = dict(embed_dim=96, depth=6, heads=6, drop_path=0.1)
    return MK_CAViT(num_classes=num_classes, img_size=img_size, **variant)

def mk_cavit_base(num_classes: int = 1000, img_size: int = 224) -> MK_CAViT:
    """Build MK-CAViT-Base (embed_dim=128, depth=8, heads=8, drop_path=0.15)."""
    variant = dict(embed_dim=128, depth=8, heads=8, drop_path=0.15)
    return MK_CAViT(num_classes=num_classes, img_size=img_size, **variant)
